{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Scikit-learn DBSCAN OD Clustering\n",
    "\n",
    "<img align=\"right\" src=\"https://anitagraser.github.io/movingpandas/assets/img/movingpandas.png\">\n",
    "\n",
    "This demo requires scikit-learn which is not a dependency of MovingPandas."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import urllib\n",
    "import os\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from geopandas import GeoDataFrame, read_file\n",
    "from shapely.geometry import Point, LineString, Polygon, MultiPoint\n",
    "from datetime import datetime, timedelta\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn.cluster import DBSCAN\n",
    "from geopy.distance import great_circle\n",
    "\n",
    "import sys\n",
    "sys.path.append(\"..\")\n",
    "import movingpandas as mpd\n",
    "\n",
    "import warnings\n",
    "warnings.simplefilter(\"ignore\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Ship movements (AIS data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df = read_file('../data/ais.gpkg')\n",
    "df['t'] = pd.to_datetime(df['Timestamp'], format='%d/%m/%Y %H:%M:%S')\n",
    "df = df.set_index('t')\n",
    "df = df[df.SOG>0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "MIN_LENGTH = 100 # meters\n",
    "TRIP_ID = 'MMSI'\n",
    "traj_collection = mpd.TrajectoryCollection(df, TRIP_ID, min_length=MIN_LENGTH)\n",
    "print(\"Finished creating {} trajectories\".format(len(traj_collection)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trips = mpd.ObservationGapSplitter(traj_collection).split(gap=timedelta(minutes=5))\n",
    "print(\"Extracted {} individual trips from {} continuous vessel tracks\".format(len(trips), len(traj_collection)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "KMS_PER_RADIAN = 6371.0088\n",
    "EPSILON = 0.1 / KMS_PER_RADIAN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trips.get_start_locations()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def make_od_line(row, od_clusters):\n",
    "    return LineString([od_clusters.loc[row['od'][0]].geometry, od_clusters.loc[row['od'][-1]].geometry])\n",
    "\n",
    "def get_centermost_point(cluster):\n",
    "    centroid = (MultiPoint(cluster).centroid.x, MultiPoint(cluster).centroid.y)\n",
    "    centermost_point = min(cluster, key=lambda point: great_circle(point, centroid).m)\n",
    "    return Point(tuple(centermost_point)[1], tuple(centermost_point)[0])\n",
    "\n",
    "def extract_od_gdf(trips):\n",
    "    origins = trips.get_start_locations()\n",
    "    origins['type'] = '0'\n",
    "    origins['traj_id'] = [trip.id for trip in trips]\n",
    "    destinations = trips.get_end_locations()\n",
    "    destinations['type'] = '1'\n",
    "    destinations['traj_id'] = [trip.id for trip in trips]\n",
    "    od = origins.append(destinations)\n",
    "    od['lat'] = od.geometry.y\n",
    "    od['lon'] = od.geometry.x\n",
    "    return od\n",
    "\n",
    "def dbscan_cluster_ods(od_gdf, eps):\n",
    "    matrix = od_gdf[['lat', 'lon']].to_numpy()\n",
    "    db = DBSCAN(eps=eps, min_samples=1, algorithm='ball_tree', metric='haversine').fit(np.radians(matrix))\n",
    "    cluster_labels = db.labels_\n",
    "    num_clusters = len(set(cluster_labels))\n",
    "    clusters = pd.Series([matrix[cluster_labels == n] for n in range(num_clusters)])\n",
    "    return cluster_labels, clusters\n",
    "\n",
    "def extract_od_clusters(od_gdf, eps):    \n",
    "    cluster_labels, clusters = dbscan_cluster_ods(od_gdf, eps)\n",
    "    od_gdf['cluster'] = cluster_labels\n",
    "    od_by_cluster = pd.DataFrame(od_gdf).groupby(['cluster'])\n",
    "    clustered = od_by_cluster['ShipType'].unique().to_frame(name='types')\n",
    "    clustered['n'] = od_by_cluster.size()\n",
    "    clustered['symbol_size'] = clustered['n']*10 # for visualization purposes\n",
    "    clustered['sog'] = od_by_cluster['SOG'].mean()\n",
    "    clustered['geometry'] = clusters.map(get_centermost_point) \n",
    "    clustered = clustered[clustered['n']>0].sort_values(by='n', ascending=False)\n",
    "    return clustered\n",
    "    \n",
    "def extract_od_matrix(trips, eps, directed=True):\n",
    "    od_gdf = extract_od_gdf(trips)\n",
    "    matrix_nodes = extract_od_clusters(od_gdf, eps)\n",
    "    od_by_traj_id = pd.DataFrame(od_gdf).sort_values(['type']).groupby(['traj_id']) # Groupby preserves the order of rows within each group.\n",
    "    od_by_traj_id = od_by_traj_id['cluster'].unique().to_frame(name='clusters')  # unique() preserves input order according to https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.unique.html\n",
    "    if directed: \n",
    "        od_matrix = od_by_traj_id.groupby(od_by_traj_id['clusters'].apply(tuple)).count().rename({'clusters':'n'}, axis=1)\n",
    "    else:\n",
    "        od_matrix = od_by_traj_id.groupby(od_by_traj_id['clusters'].apply(sorted).apply(tuple)).count().rename({'clusters':'n'}, axis=1)\n",
    "    od_matrix['od'] = od_matrix.index\n",
    "    od_matrix['geometry'] = od_matrix.apply(lambda x: make_od_line(row=x, od_clusters=matrix_nodes), axis=1 )\n",
    "    return od_matrix, matrix_nodes\n",
    "    \n",
    "od_matrix, matrix_nodes = extract_od_matrix(trips, EPSILON*2, directed=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.max(od_matrix.n)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from holoviews import dim\n",
    "\n",
    "( GeoDataFrame(od_matrix).hvplot(title='OD flows', geo=True, tiles='OSM', line_width=dim('n'), alpha=0.5, frame_height=600, frame_width=600) *\n",
    "  GeoDataFrame(matrix_nodes).hvplot(c='sog', size='symbol_size', hover_cols=['cluster', 'n'], geo=True,  cmap='RdYlGn')\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Bird migration data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df = read_file('../data/gulls.gpkg')\n",
    "df['t'] = pd.to_datetime(df['timestamp'])\n",
    "df = df.set_index('t')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "traj_collection = mpd.TrajectoryCollection(df, 'individual-local-identifier', min_length=MIN_LENGTH)     \n",
    "print(\"Finished creating {} trajectories\".format(len(traj_collection)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trips = mpd.TemporalSplitter(traj_collection).split(mode='month')\n",
    "print(\"Extracted {} individual trips from {} continuous tracks\".format(len(trips), len(traj_collection)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "EPSILON = 100 / KMS_PER_RADIAN\n",
    "\n",
    "def extract_od_gdf(trips):\n",
    "    origins = trips.get_start_locations()\n",
    "    origins['type'] = '0'\n",
    "    origins['traj_id'] = [trip.id for trip in trips]\n",
    "    destinations = trips.get_end_locations()\n",
    "    destinations['type'] = '1'\n",
    "    destinations['traj_id'] = [trip.id for trip in trips]\n",
    "    od = origins.append(destinations)\n",
    "    od['lat'] = od.geometry.y\n",
    "    od['lon'] = od.geometry.x\n",
    "    return od\n",
    "\n",
    "def extract_od_clusters(od_gdf, eps):    \n",
    "    cluster_labels, clusters = dbscan_cluster_ods(od_gdf, eps)\n",
    "    od_gdf['cluster'] = cluster_labels\n",
    "    od_by_cluster = pd.DataFrame(od_gdf).groupby(['cluster'])\n",
    "    clustered = od_by_cluster.size().to_frame(name='n')\n",
    "    clustered['geometry'] = clusters.map(get_centermost_point) \n",
    "    clustered = clustered[clustered['n']>0].sort_values(by='n', ascending=False)\n",
    "    return clustered\n",
    "\n",
    "od_matrix, matrix_nodes = extract_od_matrix(trips, EPSILON, directed=False)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "( GeoDataFrame(od_matrix).hvplot(title='OD flows', geo=True, tiles='OSM', hover_cols=['n'], line_width=dim('n')*0.05, alpha=0.5, frame_height=600, frame_width=600) *\n",
    "  GeoDataFrame(matrix_nodes).hvplot(c='n', size=dim('n')*0.1, hover_cols=['cluster', 'n'], geo=True,  cmap='RdYlGn')\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Comparing OD flows and TrajectoryCollectionAggregator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "aggregator = mpd.TrajectoryCollectionAggregator(trips, max_distance=1000000, min_distance=100000, min_stop_duration=timedelta(minutes=5))\n",
    "\n",
    "flows = aggregator.get_flows_gdf()\n",
    "clusters = aggregator.get_clusters_gdf()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "( flows.hvplot(title='Generalized aggregated trajectories', geo=True, hover_cols=['weight'], line_width='weight', alpha=0.5, color='#1f77b3', tiles='OSM', frame_height=600, frame_width=400) * \n",
    "  clusters.hvplot(geo=True, color='red', size='n') \n",
    "+\n",
    " GeoDataFrame(od_matrix).hvplot(title='OD flows', geo=True, tiles='OSM', hover_cols=['n'], line_width=dim('n')*0.05, alpha=0.5, frame_height=600, frame_width=400) *\n",
    "  GeoDataFrame(matrix_nodes).hvplot(c='n', size=dim('n')*0.1, hover_cols=['cluster', 'n'], geo=True,  cmap='RdYlGn')\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
